

\section{增加一个新 Op }\label{ux589eux52a0ux4e00ux4e2aux65b0-op}

预备知识:

\begin{itemize}
\tightlist
\item
  对 C++ 有一定了解.
\item
  已经\href{tensorflow-zh/SOURCE/get_started/introduction.md\#source}{下载
  TensorFlow 源代码}并有能力编译它.
\end{itemize}

如果现有的库没有涵盖你想要的操作, 你可以自己定制一个. 为了使定制的 Op
能够兼容原有的库 , 你必须做以下工作:

\begin{itemize}
\tightlist
\item
  在一个 C++ 文件中注册新 Op. Op 的注册与实现是相互独立的.
  在其注册时描述了 Op 该如何执行. 例如, 注册 Op 时定义了 Op 的名字,
  并指定了它的输入和输出.
\item
  使用 C++ 实现 Op. 每一个实现称之为一个 ``kernel'', 可以存在多个
  kernel, 以适配不同的架构 (CPU, GPU 等)或不同的输入/输出类型.
\item
  创建一个 Python 包装器（wrapper）. 这个包装器是创建 Op 的公开 API.
  当注册 Op 时, 会自动生成一个默认 默认的包装器.
  既可以直接使用默认包装器, 也可以添加一个新的包装器.
\item
  (可选) 写一个函数计算 Op 的梯度.
\item
  (可选) 写一个函数, 描述 Op 的输入和输出 shape. 该函数能够允许从 Op
  推断 shape.
\item
  测试 Op, 通常使用
  Pyhton。如果你定义了梯度，你可以使用Python的\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/kernel_tests/gradient_checker.py}{GradientChecker}来测试它。
\end{itemize}

\subsection{内容}\label{ux5185ux5bb9}

\subsubsection{\texorpdfstring{\protect\hyperlink{AUTOGENERATED-adding-a-new-op}{增加一个新
Op}}{增加一个新 Op}}\label{ux589eux52a0ux4e00ux4e2aux65b0-op-1}

\begin{itemize}
\tightlist
\item
  \protect\hyperlink{defineux5finterface}{定义 Op 的接口}
\item
  \protect\hyperlink{AUTOGENERATED-implement-the-kernel-for-the-op}{为
  Op 实现 kernel}
\item
  \protect\hyperlink{AUTOGENERATED-generate-the-client-wrapper}{生成客户端包装器}
\item
  \protect\hyperlink{AUTOGENERATED-the-python-op-wrapper}{Python Op
  包装器}
\item
  \protect\hyperlink{AUTOGENERATED-the-c---op-wrapper}{C++ Op 包装器}
\item
  \protect\hyperlink{AUTOGENERATED-verify-it-works}{检查 Op
  能否正常工作}
\item
  \protect\hyperlink{Validation}{验证条件}
\item
  \protect\hyperlink{AUTOGENERATED-op-registration}{Op 注册}
\item
  \protect\hyperlink{Attrs}{属性}
\item
  \protect\hyperlink{AUTOGENERATED-attr-types}{属性类型}
\item
  \protect\hyperlink{Polymorphism}{多态}
\item
  \protect\hyperlink{AUTOGENERATED-inputs-and-outputs}{输入和输出}
\item
  \protect\hyperlink{AUTOGENERATED-backwards-compatibility}{向后兼容性}
\item
  \protect\hyperlink{mult-archs}{GPU 支持}
\item
  \protect\hyperlink{AUTOGENERATED-implement-the-gradient-in-python}{使用
  Python 实现梯度}
\item
  \protect\hyperlink{AUTOGENERATED-implement-a-shape-function-in-python}{使用
  Python 实现 shape 函数}
\end{itemize}

\subsection{定义 Op 的接口 }\label{ux5b9aux4e49-op-ux7684ux63a5ux53e3}

向 TensorFlow 系统注册来定义 Op 的接口. 在注册时, 指定 Op 的名称,
它的输入(类型和名称) 和输出(类型和名称), 和所需要任何
\protect\hyperlink{Attrs}{属性}的文档说明.

为了让你有直观的认识, 创建一个简单的 Op 作为例子. 该 Op 接受一个
\texttt{int32} 类型 tensor 作为 输入, 输出这个 tensor 的一个副本,
副本与原 tensor 唯一的区别在于第一个元素被置为 0. 创建 文件
\texttt{tensorflow/core/user\_ops/zero\_out.cc}, 并调用
\texttt{REGISTER\_OP} 宏来定义 Op 的接口.

\begin{verbatim}
 #include "tensorflow/core/framework/op.h"
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32");
\end{verbatim}

\texttt{ZeroOut} Op 接受 32 位整型的 tensor \texttt{to\_zero} 作为输入,
输出 32 位整型的 tensor \texttt{zeroed}.

\begin{quote}
命名的注意事项: Op 的名称必须是为唯一的, 并使用驼峰命名法. 以下划线
\texttt{\_} 开始的名称保留为内部使用.
\end{quote}

\subsection{为 Op 实现 kernel }\label{ux4e3a-op-ux5b9eux73b0-kernel}

在定义接口之后, 提供一个或多个 Op 的实现. 为这些 kernel
的每一个创建一个对应的类, 继承 \texttt{OpKernel}, 覆盖 \texttt{Compute}
方法. \texttt{Compute} 方法提供一个类型为 \texttt{OpKernelContext*}
的参数 \texttt{context}, 用于访问一些有用的信息, 例如输入和输出的
tensor.

将 kernel 添加到刚才创建的文件中, kernel 看起来和下面的代码类似:

\begin{verbatim}
 #include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    // 获取输入 tensor.
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();
   // 创建一个输出 tensor.
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->template flat<int32>();
    // 设置 tensor 除第一个之外的元素均设为 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output(i) = 0;
    }
    // 尽可能地保留第一个元素的值.
    if (N > 0) output(0) = input(0);
  }
};
\end{verbatim}

实现 kernel 后, 将其注册到 TensorFlow 系统中. 注册时, 可以指定该 kernel
运行时的多个约束 条件. 例如可以指定一个 kernel 在 CPU 上运行, 另一个在
GPU 上运行.

将下列代码加入到 \texttt{zero\_out.cc} 中, 注册 \texttt{ZeroOut} op:

\begin{verbatim}
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
\end{verbatim}

一旦\href{tensorflow-zh/SOURCE/get_started/os_setup.md\#create-pip}{创建和重新安装了
TensorFlow}, Tensorflow 系统可以在需要时引用和使用该 Op.

\subsection{生成客户端包装器
}\label{ux751fux6210ux5ba2ux6237ux7aefux5305ux88c5ux5668}

\subsubsection{Python Op 包装器 }\label{python-op-ux5305ux88c5ux5668}

当编译 TensorFlow 时, 所有放在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/user_ops/}{\texttt{tensorflow/core/user\_ops}}
目录下 的 Op 会自动在
\texttt{bazel-genfiles/tensorflow/python/ops/gen\_user\_ops.py} 文件
中生成 Python Op 包装器. 通过以下声明, 把那些 Op 引入到
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/user_ops/user_ops.py}{\texttt{tensorflow/python/user\_ops/user\_ops.py}}
中:

\begin{Shaded}
\begin{Highlighting}[]
\ImportTok{from} \NormalTok{tensorflow.python.ops.gen_user_ops }\ImportTok{import} \OperatorTok{*}
\end{Highlighting}
\end{Shaded}

你可以选择性将部分函数替换为自己的实现. 为此, 首先要隐藏自动生成的代码,
在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/BUILD}{\texttt{tensorflow/python/BUILD}}
文件中, 将其名字添加到 \texttt{"user\_ops"} 的 \texttt{hidden} 列表.

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{tf_gen_op_wrapper_py(}
    \NormalTok{name }\OperatorTok{=} \StringTok{"user_ops"}\NormalTok{,}
    \NormalTok{hidden }\OperatorTok{=} \NormalTok{[}
        \StringTok{"Fact"}\NormalTok{,}
    \NormalTok{],}
    \NormalTok{require_shape_functions }\OperatorTok{=} \VariableTok{False}\NormalTok{,}
\NormalTok{)}
\end{Highlighting}
\end{Shaded}

紧接着 \texttt{"Fact"} 列出自己的 Op. 然后, 在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/user_ops/user_ops.py}{\texttt{tensorflow/python/user\_ops/user\_ops.py}}
中添加你的替代实现函数. 通常, 替代实现函数也会调用自动生成函数来真正把
Op 添加 到图中. 被隐藏的自动生成函数位于 \texttt{gen\_user\_ops} 包中,
名称多了一个下划线前缀 (``\texttt{\_}''). 例如:

\begin{Shaded}
\begin{Highlighting}[]
\KeywordTok{def} \NormalTok{my_fact():}
    \CommentTok{"""覆盖一个 Op 自动生成代码的示例."""}
    \ControlFlowTok{return} \NormalTok{gen_user_ops._fact()}
\end{Highlighting}
\end{Shaded}

\subsubsection{C++ Op 包装器 }\label{c-op-ux5305ux88c5ux5668}

当编译 TensorFlow 时, 所有
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/user_ops/}{\texttt{tensorflow/core/user\_ops}}
文件夹 下的 Op 会自动创建 C++ Op 包装器. 例如,
\texttt{tensorflow/core/user\_ops/zero\_out.cc} 中的 Op 会自动在
\texttt{bazel-genfiles/tensorflow/cc/ops/user\_ops.\{h,cc\}}
中生成包装器.

\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/cc/ops/standard_ops.h}{\texttt{tensorflow/cc/ops/standard\_ops.h}}
通过下述申明, 导入用户自定义 Op 自动生成的包装器.

\begin{verbatim}
 #include "tensorflow/cc/ops/user_ops.h"
\end{verbatim}

\subsection{检查 Op 能否正常工作
}\label{ux68c0ux67e5-op-ux80fdux5426ux6b63ux5e38ux5de5ux4f5c}

验证已经成功实现 Op 的方式是编写测试程序. 创建文件
\texttt{tensorflow/python/kernel\_tests/zero\_out\_op\_test.py},
包含以下内容:

\begin{Shaded}
\begin{Highlighting}[]
\ImportTok{import} \NormalTok{tensorflow }\ImportTok{as} \NormalTok{tf}
\KeywordTok{class} \NormalTok{ZeroOutTest(tf.test.TestCase):}
  \KeywordTok{def} \NormalTok{testZeroOut(}\VariableTok{self}\NormalTok{):}
    \ControlFlowTok{with} \VariableTok{self}\NormalTok{.test_session():}
      \NormalTok{result }\OperatorTok{=} \NormalTok{tf.user_ops.zero_out([}\DecValTok{5}\NormalTok{, }\DecValTok{4}\NormalTok{, }\DecValTok{3}\NormalTok{, }\DecValTok{2}\NormalTok{, }\DecValTok{1}\NormalTok{])}
      \VariableTok{self}\NormalTok{.assertAllEqual(result.}\BuiltInTok{eval}\NormalTok{(), [}\DecValTok{5}\NormalTok{, }\DecValTok{0}\NormalTok{, }\DecValTok{0}\NormalTok{, }\DecValTok{0}\NormalTok{, }\DecValTok{0}\NormalTok{])}
\end{Highlighting}
\end{Shaded}

然后运行测试:

\begin{verbatim}
$ bazel test tensorflow/python:zero_out_op_test
\end{verbatim}

\subsection{验证条件 }\label{ux9a8cux8bc1ux6761ux4ef6}

上述示例假定 Op 能够应用在任何 shape 的 tensor 上. 如果只想应用到 vector
上 呢? 这意味需要在上述 OpKernel 实现中添加相关的检查.

\begin{verbatim}
  void Compute(OpKernelContext* context) override {
   // 获取输入 tensor
    const Tensor& input_tensor = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }
\end{verbatim}

OP\_REQUIRES 断言的输入是一个 vector, 如果不是 vector, 将设置
\texttt{InvalidArgument} 状态并返回.
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/lib/core/errors.h}{\texttt{OP\_REQUIRES}
宏} 有三个参数:

\begin{itemize}
\tightlist
\item
  \texttt{context}: 可以是一个 \texttt{OpKernelContext} 或
  \texttt{OpKernelConstruction} 指针 (参见
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_kernel.h}{\texttt{tensorflow/core/framework/op\_kernel.h}}),
  其 \texttt{SetStatus()} 方法将被使用到.
\item
  检查条件:
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/public/tensor_shape.h}{\texttt{tensorflow/core/public/tensor\_shape.h}}
  中有一些验证 tensor shape 的函数.
\item
  条件不满足时产生的错误: 错误用一个 \texttt{Status} 对象表示, 参见
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/public/status.h}{\texttt{tensorflow/core/public/status.h}}.
  \texttt{Status} 包含一个类型 (通常是 \texttt{InvalidArgument},
  但也可以是任何类型) 和一个消息. 构造 一个错误的函数位于
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/lib/core/errors.h}{\texttt{tensorflow/core/lib/core/errors.h}}
  中.
\end{itemize}

如果想要测试一个函数返回的 \texttt{Status} 对象是否是一个错误, 可以使用
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/lib/core/errors.h}{\texttt{OP\_REQUIRES\_OK}}.
这些宏如果检测到错误, 会直接跳出函数, 终止函数执行.

\subsection{Op 注册 }\label{op-ux6ce8ux518c}

\subsubsection{属性 }\label{ux5c5eux6027}

Op 可以有属性, 属性的值在 Op 添加到图中时被设置. 属性值用于配置 Op, 在
kernel 实现中, Op 注册的输入和输出类型中, 均可访问这些属性值.
尽可能地使用输入代替属性, 因为输入的灵活性更高, 例如可以在执行步骤中
中被更改, 可以使用 feed 等等. 属性可用于实现一些输入无法做到的事情,
例如影响 Op 签名 (即输入输出的数量和类型)
的配置或只读配置可以通过属性实现.

注册 Op 时可以用 \texttt{Attr} 方法指定属性的名称和类型,
以此来定义一个属性, 形式如下:

\begin{verbatim}
<name>: <attr-type-expr>
\end{verbatim}

\texttt{\textless{}name\textgreater{}} 必须以字母开头, 可以由数字, 字母,
下划线组成. \texttt{\textless{}attr-type-expr\textgreater{}}
是一个类型表达式, 形式\protect\hyperlink{attr-types}{如下}:

例如, 如果想要 \texttt{ZeroOut} Op 保存一个用户索引, 指示该 Op
不仅仅只有一个元素, 你可以注册 Op 如下:

\begin{verbatim}
REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");
\end{verbatim}

你的 kernel 可以在构造函数里, 通过 \texttt{context} 参数访问这个属性:

\begin{verbatim}
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) {
   // 获取欲保存的索引值
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // 检查 preserve_index 是否为正
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
}
 private:
  int preserve_index_;
};
\end{verbatim}

该值可以在 \texttt{Compute} 方法中被使用:

\begin{verbatim}
void Compute(OpKernelContext* context) override {
    // ...
   // 检查 preserve_index 范围是否合法
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));
    // 设置输出 tensor 所有的元素值为 0
   const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }
    // 保存请求的输入值
   output_flat(preserve_index_) = input(preserve_index_);
  }
\end{verbatim}

\begin{quote}
为了维持\protect\hyperlink{backwards-compatibility}{向后兼容性},
将一个属性添加到一个已有的 Op 时,
必须指定一个\protect\hyperlink{default-values-constraints}{默认值}:
\end{quote}

\begin{verbatim}
REGISTER_OP("ZeroOut")
     .Attr("preserve_index: int = 0")
     .Input("to_zero: int32")
     .Output("zeroed: int32");
\end{verbatim}

\subsubsection{属性类型 }\label{ux5c5eux6027ux7c7bux578b}

属性可以使用下面的类型:

\begin{itemize}
\tightlist
\item
  \texttt{string}: 任何二进制字节流 (UTF8 不是必须的).
\item
  \texttt{int}: 一个有型整数.
\item
  \texttt{float}: 一个浮点数.
\item
  \texttt{bool}: 真或假.
\item
  \texttt{type}:
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/types.cc\#DataTypeString}{\texttt{DataType}}
  非引用类型之一.
\item
  \texttt{shape}: 一个
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/tensor_shape.proto}{\texttt{TensorShapeProto}}.
\item
  \texttt{tensor}: 一个
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/tensor.proto}{\texttt{TensorProto}}.
\item
  \texttt{list(\textless{}type\textgreater{})}:
  \texttt{\textless{}type\textgreater{}} 列表, 其中
  \texttt{\textless{}type\textgreater{}} 是上述类型之一. 注意
  \texttt{list(list(\textless{}type\textgreater{}))} 是无效的.
\end{itemize}

权威的列表以
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_def_builder.cc\#FinalizeAttr}{\texttt{op\_def\_builder.cc:FinalizeAttr}}
为准.

\paragraph{默认值和约束条件
}\label{ux9ed8ux8ba4ux503cux548cux7ea6ux675fux6761ux4ef6}

属性可能有默认值, 一些类型的属性可以有约束条件.
为了定义一个有约束条件的属性, 你可以使用下列的
\texttt{\textless{}attr-type-expr\textgreater{}} 形式:

\begin{itemize}
\tightlist
\item
  \texttt{\{\textquotesingle{}\textless{}string1\textgreater{}\textquotesingle{},\ \textquotesingle{}\textless{}string2\textgreater{}\textquotesingle{}\}}:
  属性值必须是一个字符串, 取值可以为
  \texttt{\textless{}string1\textgreater{}} 或
  \texttt{\textless{}string2\textgreater{}}.
  值的语法已经暗示了值的类型为 \texttt{string}, 已经暗示了.
  下述语句模拟了一个枚举值:
\end{itemize}

\begin{verbatim}
REGISTER_OP("EnumExample")
      .Attr("e: {'apple', 'orange'}");
\end{verbatim}

\begin{itemize}
\tightlist
\item
  \texttt{\{\textless{}type1\textgreater{},\ \textless{}type2\textgreater{}\}}:
  值是 \texttt{type} 类型, 且必须为
  \texttt{\textless{}type1\textgreater{}} 或
  \texttt{\textless{}type2\textgreater{}} 之一, 当然
  \texttt{\textless{}type1\textgreater{}} 和
  \texttt{\textless{}type2\textgreater{}} 必须都是有效的
  \href{tensorflow-zh/SOURCE/resources/dims_types.md\#data-types}{tensor
  类型}. 你无须指定属性的类型为 \texttt{type}, 而是通过 \texttt{\{...\}}
  语句给出一个类型列表. 例如, 在下面的例子里, 属性 \texttt{t}
  的类型必须为 \texttt{int32}, \texttt{float}, 或 \texttt{bool}:
\end{itemize}

\begin{verbatim}
REGISTER_OP("RestrictedTypeExample")
      .Attr("t: {int32, float, bool}");
\end{verbatim}

\begin{itemize}
\item
  这里有一些常见类型约束条件的快捷方式:
\item
  \texttt{numbertype}: 限制类型为数字类型, 即非 string 非 bool 的类型.
\item
  \texttt{realnumbertype}: 与 \texttt{numbertype} 区别是不支持复杂类型.
\item
  \texttt{quantizedtype}: 与 \texttt{numbertype} 区别是只支持量化数值
  (quantized number type).
\end{itemize}

这些类型的列表在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/types.h}{\texttt{tensorflow/core/framework/types.h}}
文件中通过函数定义 (如 \texttt{NumberTypes()}). 本例中属性 \texttt{t}
必须为某种数字类型:

\begin{verbatim}
REGISTER_OP("NumberType")
        .Attr("t: numbertype");
\end{verbatim}

对于这个 Op:

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{tf.number_type(t}\OperatorTok{=}\NormalTok{tf.int32)  }\CommentTok{# 有效}
\NormalTok{tf.number_type(t}\OperatorTok{=}\NormalTok{tf.}\BuiltInTok{bool}\NormalTok{)   }\CommentTok{# 无效}
\end{Highlighting}
\end{Shaded}

\begin{itemize}
\tightlist
\item
  \texttt{int\ \textgreater{}=\ \textless{}n\textgreater{}}:
  值必须是一个整数, 且取值大于等于 \texttt{\textless{}n\textgreater{}},
  \texttt{\textless{}n\textgreater{}} 是一个自然数.
\end{itemize}

例如, 下列 Op 注册操作指定了属性 \texttt{a} 的取值至少为 \texttt{2}.

\begin{verbatim}
REGISTER_OP("MinIntExample")
      .Attr("a: int >= 2");
\end{verbatim}

\begin{itemize}
\tightlist
\item
  \texttt{list(\textless{}type\textgreater{})\ \textgreater{}=\ \textless{}n\textgreater{}}:
  一个 \texttt{\textless{}type\textgreater{}} 类型列表,
  列表长度必须大于等于 \texttt{\textless{}n\textgreater{}}.
\end{itemize}

例如, 下面的 Op 注册操作指定属性 \texttt{a} 是一个列表,
列表中的元素类型是 \texttt{int32} 或 \texttt{float}列表长度至少为3.

\begin{verbatim}
REGISTER_OP("TypeListExample")
      .Attr("a: list({int32, float}) >= 3");
\end{verbatim}

通过添加 \texttt{=\ \textless{}default\textgreater{}} 到约束条件末尾,
给一个属性设置默认值 (使其在自动生成的代码里 变成可选属性), 如下:

\begin{verbatim}
REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");
\end{verbatim}

默认值支持的语法将在最终 GraphDef 定义的 protobuf 表示中被使用.

下面是给所有类型赋予默认值的例子:

\begin{verbatim}
REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");
\end{verbatim}

请特别注意那些类型值里面包含的
\href{tensorflow-zh/SOURCE/resources/dims_types.md\#data-types}{\texttt{DT\_*}
名称}.

\subsubsection{多态 }\label{ux591aux6001}

\hypertarget{type-polymorphism}{\paragraph{Type Polymorphism
}\label{type-polymorphism}}

对于那些可以使用不同类型输入或产生不同类型输出的 Op, 可以注册 Op
时为输入/输出类型里指定一个\protect\hyperlink{attrs}{属性}. 一般紧接着,
会为每一个支持的类型注册一个 \texttt{OpKernel}.

例如, 除了 \texttt{int32} 外, 想要 \texttt{ZeroOut} Op 支持
\texttt{float}, 注册代码如下:

\begin{verbatim}
REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: <b>T</b>")
    .Output("zeroed: <b>T</b>");
\end{verbatim}

这段 Op 注册代码现在指定了输入的类型必须为 \texttt{float} 或
\texttt{int32}, 而且 既然输入和输出制定了同样的类型 \texttt{T},
输出也同样如此.

\begin{quote}
一个命名建议:\{\#naming\} 输入, 输出, 和属性通常使用 snake\_case 命名法.
唯一的例外是属性被用作输入类型或是输入类型的一部分. 当添加到图中时,
这些属性 可以被推断出来, 因此不会出现在 Op 的函数里. 例如, 最后一个
ZeroOut 定义 生成的 Python 函数如下:
\end{quote}

\begin{Shaded}
\begin{Highlighting}[]
\KeywordTok{def} \NormalTok{zero_out(to_zero, name}\OperatorTok{=}\VariableTok{None}\NormalTok{):}
   \CommentTok{"""...}
\CommentTok{   参数:}
\CommentTok{     to_zero: 一个 `Tensor`. 必须为下列类型之一:}
\CommentTok{         `float32`, `int32`.}
\CommentTok{     name: 操作的名字 (可选).}

\CommentTok{   返回值:}
\CommentTok{     一个 `Tensor`, 类型和 `to_zero` 一样.}
\CommentTok{   """}
\end{Highlighting}
\end{Shaded}

\begin{quote}
如果输入的 \texttt{to\_zero} 是一个 \texttt{int32} 的tensor, 然后
\texttt{T} 将被自动 设置为 \texttt{int32} (实际上是 \texttt{DT\_INT32}).
那些推导出的属性的名称字母全大写 或采用驼峰命名法.

下面是一个输出类型自动推断的例子, 读者可以对比一下:
\end{quote}

\begin{verbatim}
REGISTER_OP("StringToNumber")
     .Input("string_tensor: string")
     .Output("output: out_type")
     .Attr("out_type: {float, int32}");
     .Doc(R"doc(
 Converts each string in the input Tensor to the specified numeric type.
 )doc");
\end{verbatim}

\begin{quote}
在这种情况下, 用户需要在生成的 Python 代码中指定输出类型.
\end{quote}

\begin{Shaded}
\begin{Highlighting}[]
\KeywordTok{def} \NormalTok{string_to_number(string_tensor, out_type}\OperatorTok{=}\VariableTok{None}\NormalTok{, name}\OperatorTok{=}\VariableTok{None}\NormalTok{):}
   \CommentTok{"""将输入 Tensor 中的每一个字符串转化成指定的数字类型}

\CommentTok{   参数:}
\CommentTok{     string_tensor: 一个 `string` 类型的 `Tensor`.}
\CommentTok{     out_type: 一个可选的 `tf.DType`, 取值为 `tf.float32, tf.int32`.}
\CommentTok{       默认值是 `tf.float32`.}
\CommentTok{     name: 操作的名称 (可选).}

\CommentTok{   返回值:}
\CommentTok{     一个 `out_type` 类型的 `Tensor`.}
\CommentTok{   """}
\end{Highlighting}
\end{Shaded}

\begin{verbatim}
 #include "tensorflow/core/framework/op_kernel.h"
class ZeroOutInt32Op : public OpKernel {
  // 和之前一样
};
class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction * context)
      : OpKernel(context) {}
  void Compute(OpKernelContext * context) override {
    // 获取输入 tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();
    // 创建一个输出 tensor
    Tensor * output = NULL;
    OP_REQUIRES_OK(context,
                    context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();
    // 设置输出 tensor 的所有元素为 0
    const int N = input.size();
    for (int i = 0; i &lt; N; i++) {
      output_flat(i) = 0;
    }<br/>
    // 保留第一个输入值
    if (N &gt; 0) output_flat(0) = input(0);
  }
};
// 注意, TypeConstraint<int32>("T") 意味着属性 "T" (在上面 Op 注册代码中
// 定义的) 必须是 "int32", 才能实例化.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint&lt;int32&gt;("T"),
    ZeroOutOpInt32);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);
\end{verbatim}

\begin{quote}
为了保持\protect\hyperlink{backwards-compatibility}{向后兼容性},
你在为一个 已有的 op 添加属性时,
必须指定一个\protect\hyperlink{default-values-constraints}{默认值}:
\end{quote}

\begin{verbatim}
REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")
\end{verbatim}

如果需要添加更多类型, 例如 \texttt{double}:

\begin{verbatim}
REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");
\end{verbatim}

为了避免为新增的类型写冗余的 \texttt{OpKernel} 代码, 通常可以写一个 C++
模板作为替代. 当然, 仍然需要为每一个重载版本定义一个 keneral 注册
(\texttt{REGISTER\textbackslash{}\_KERNEL\textbackslash{}\_BUILDER}
调用).

\begin{verbatim}
template <typename T>;
class ZeroOutOp : public OpKernel {
 public:
    explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    // 获取输入 tensor
     const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();
    // 创建一个输出 tensor
      Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();
    // 设置输出 tensor 的所有元素为 0
   const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }
    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};
};<br/>
// 注意, TypeConstraint<int32>("T") 意味着属性 "T" (在上面 Op 注册代码中
// 定义的) 必须是 "int32", 才能实例化. </b>
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);
\end{verbatim}

如果有很多重载版本, 可以将注册操作通过一个宏来实现.

\begin{verbatim}
 #include "tensorflow/core/framework/op_kernel.h"
 #define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)
REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
 #undef REGISTER_KERNEL
\end{verbatim}

取决于注册 kernel 使用哪些类型,
你可能可以使用\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/register_types.h}{\texttt{tensorflow/core/framework/register\_types.h}}
提供的宏:

\begin{verbatim}
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");
template <typename T>
class ZeroOutOp : public OpKernel { ... };
 #define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
 #undef REGISTER_KERNEL
\end{verbatim}

\paragraph{列表输入和输出
}\label{ux5217ux8868ux8f93ux5165ux548cux8f93ux51fa}

除了能够使用不同类型的 tensor 作为输入或输出, Op 还支持使用多个 tensor
作为输入或输出.

在接下来的例子里, 属性 \texttt{T} 存储了一个类型\emph{列表},
并同时作为输入 \texttt{in} 和输出 \texttt{out} 的类型.
输入和输出均为指定类型的 tensor 列表. 既然输入和输出的类型均为
\texttt{T}, 它们的 tensor 数量和类型 是一致的.

\begin{verbatim}
REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");
\end{verbatim}

可以为列表中可存放的类型设置约束条件. 在下一个例子中, 输入是
\texttt{float} 和 \texttt{double} 类型的 tensor 列表. 例如, 这个 Op
可接受的 输入类型为 \texttt{(float,\ double,\ float)} 的数据,
且在此情况下, 输出类型同样 为 \texttt{(float,\ double,\ float)}.

\begin{verbatim}
REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");
\end{verbatim}

如果想要一个列表中的所有 tensor 是同一类型, 你需要写下列代码:

\begin{verbatim}
REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");
\end{verbatim}

这段代码接受 \texttt{int32} tensor 列表, 并用一个 \texttt{int} 属性
\texttt{N} 来指定列表的长度.

这也可用于\protect\hyperlink{type-polymorphism}{类型推断}.
在下一个例子中, 输入是一个 tensor 列表, 长度为 \texttt{"N"}, 类型为
\texttt{"T"}, 输出是单个 \texttt{"T"} 的 tensor:

\begin{verbatim}
REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");
\end{verbatim}

默认情况下, tensor 列表的最小长度为1. 这个约束条件可以通过
\protect\hyperlink{default-values-constraints}{为指定的属性增加一个
\texttt{"\textgreater{}="} 约束}来变更:

\begin{verbatim}
REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");
\end{verbatim}

同样的语法也适用于 \texttt{"list(type)"} 属性:

\begin{verbatim}
REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");
\end{verbatim}

\subsubsection{输入和输出 }\label{ux8f93ux5165ux548cux8f93ux51fa}

总结一下上述内容, 一个 Op 注册操作可以指定多个输入和输出:

\begin{verbatim}
REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");
\end{verbatim}

每一个输入或输出形式如下:

\begin{verbatim}
<name>: <io-type-expr>
\end{verbatim}

其中, \texttt{\textless{}name\textgreater{}} 以字母打头, 且只能由数字,
字母和下划线组成. \texttt{\textless{}io-type-expr\textgreater{}} 可以是
下列类型表达式之一:

\begin{itemize}
\tightlist
\item
  \texttt{\textless{}type\textgreater{}}, 一个合法的输入类型, 如
  \texttt{float}, \texttt{int32}, \texttt{string}.
  这可用于指定给定类型的单个 tensor.
\end{itemize}

参见\href{tensorflow-zh/SOURCE/resources/dims_types.md\#data-types}{合法
Tensor 类型列表}.

\begin{verbatim}
REGISTER_OP("BuiltInTypesExample")
      .Input("integers: int32")
      .Input("complex_numbers: scomplex64");
\end{verbatim}

\begin{itemize}
\tightlist
\item
  \texttt{\textless{}attr-type\textgreater{}},
  一个\protect\hyperlink{attrs}{属性}和一个类型 \texttt{type} 或类型列表
  \texttt{list(type)}(可能 包含类型限制).
  该语法可实现\protect\hyperlink{Polymorphism}{多态 Op}.
\end{itemize}

\begin{verbatim}
REGISTER_OP("PolymorphicSingleInput")
      .Attr("T: type")
      .Input("in: T);
REGISTER_OP("RestrictedPolymorphicSingleInput")
      .Attr("T: {int32, int64}")
      .Input("in: T);
\end{verbatim}

将属性的类型设置为 \texttt{list(type)} 将允许你接受一个序列的 tensor.

\begin{verbatim}
REGISTER_OP("ArbitraryTensorSequenceExample")
      .Attr("T: list(type)")
      .Input("in: T")
      .Output("out: T");
REGISTER_OP("RestrictedTensorSequenceExample")
      .Attr("T: list({int32, int64})")
      .Input("in: T")
      .Output("out: T");
\end{verbatim}

注意, 输入和输出均为 \texttt{T}, 意味着输入和输出的类型与数量均相同.

\begin{itemize}
\tightlist
\item
  \texttt{\textless{}number\textgreater{}\ *\ \textless{}type\textgreater{}},
  一组拥有相同类型的 tensor, \texttt{\textless{}number\textgreater{}}
  是一个 \texttt{int} 类型属性的名称.
  \texttt{\textless{}type\textgreater{}}
  可以是\href{tensorflow-zh/SOURCE/resources/dims_types.md\#data-types}{一个类似于
  \texttt{int32} 和 \texttt{float} 的特定类型}, 或者一个 \texttt{type}
  类型属性的名字. 前者的例子如下, 该例子接受一个 \texttt{int32} tensor
  列表作为 Op 输入:
\end{itemize}

\begin{verbatim}
REGISTER_OP("Int32SequenceExample")
      .Attr("NumTensors: int")
      .Input("in: NumTensors * int32")
\end{verbatim}

后者的例子如下, 该例子接受一个泛型 tensor 列表作为 Op 输入:

\begin{verbatim}
REGISTER_OP("SameTypeSequenceExample")
      .Attr("NumTensors: int")
      .Attr("T: type")
      .Input("in: NumTensors * T")
\end{verbatim}

\begin{itemize}
\tightlist
\item
  Tensor 的引用表示为 \texttt{Ref(\textless{}type\textgreater{})}, 其中
  \texttt{\textless{}type\textgreater{}} 是上述类型之一.
\end{itemize}

\begin{quote}
一个命名建议: 当使用属性表示一个输入的类型时, 该类型可以被推断出来.
实现该特性, 将需要推断 的类型用大写名称表示 (如 \texttt{T} 或
\texttt{N}), 其它的输入, 输出, 和属性像使用函数参数一样使用这些
大写名称.
参见之前的\protect\hyperlink{naming}{命名建议}章节查看更多细节.
\end{quote}

更多细节参见
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_def_builder.h}{\texttt{tensorflow/core/framework/op\_def\_builder.h}}.

\subsubsection{向后兼容性 }\label{ux5411ux540eux517cux5bb9ux6027}

通常, 对规范的改变必须保持向后兼容性: Op 使用新规范后,
需保证使用旧规范构造的序列化 GraphDef 仍能正确工作.

下面是几种保持向后兼容性的方式:

\begin{enumerate}
\def\labelenumi{\arabic{enumi}.}
\tightlist
\item
  任何添加到 Op 的新属性必须有默认值, 且默认值下的行为有明确定义.
  将一个非多态的操作变为多态操作, 你\emph{必须}为新的类型属性赋予默认值,
  以保持原始的函数签名. 例如, 有如下操作:
\end{enumerate}

\begin{verbatim}
REGISTER_OP("MyGeneralUnaryOp")
       .Input("in: float")
       .Output("out: float");
\end{verbatim}

可以通过下述方式将其变为多态, 且保持向后兼容性:

\begin{verbatim}
REGISTER_OP("MyGeneralUnaryOp")
       .Input("in: T")
       .Output("out: T")
       .Attr("T: numerictype = float");
\end{verbatim}

1.放宽一个属性的约束条件是安全的. 例如, 你可以将
\texttt{\{int32,\ int64\}} 变为 \texttt{\{int32,\ int64,\ float\}},
或者, 将 \texttt{\{"apple",\ "orange"\}} 变为
\texttt{\{"apple",\ "banana",\ "orange"\}}.

2.通过给 Op 名称添加一些项目中唯一的标识作为前缀, 来为新建的 Op
添加命名空间. 命名空间 可以预防你的 Op 与 TensorFlow 未来版本里的内置 Op
产生命名冲突.

3.超前计划! 尝试着去预测 Op 未来的的用途, 超前设计, 毕竟,
一些签名的变更无法保证兼容性 (例如, 增加新的输入,
或将原来的单元素输入变成一个列表).

如果不能以兼容的方式改变一个操作, 那就创建一个全新的操作,
来实现所需功能.

\subsection{GPU 支持 }\label{gpu-ux652fux6301}

你可以实现不同的 OpKernel, 将其中之一注册到 GPU, 另一个注册到 GPU,
正如\protect\hyperlink{Polymorphism}{为不同的类型注册 kernel}一样.
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/}{\texttt{tensorflow/core/kernels/}}
中有一些 GPU 支持的例子. 注意, 一些 kernel 的 CPU 版本位于 \texttt{.cc}
文件, GPU 版本位于 \texttt{\_gpu.cu.cc} 文件, 共享的代码位于 \texttt{.h}
文件.

例如,
\href{tensorflow-zh/SOURCE/api_docs/python/array_ops.md\#pad}{\texttt{pad}
op} 除了 GPU kernel 外的其它代码 均在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/pad_op.cc}{\texttt{tensorflow/core/kernels/pad\_op.cc}}
中. GPU kernel 位于
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/pad_op_gpu.cu.cc}{\texttt{tensorflow/core/kernels/pad\_op\_gpu.cu.cc}},
共享的一个模板类代码定义在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/pad_op.h}{\texttt{tensorflow/core/kernels/pad\_op.h}}.
需要注意的事情是, 即使使用 \texttt{pad} 的 GPU 版本时, 仍然需要将
\texttt{"paddings"} 输入放置到内存中. 为了实现这一点,
将输入或输出标记为必须保存在内存中, 为 kernel 注册一个
\texttt{HostMemory()} 调用. 如下:

\begin{verbatim}
 #define REGISTER_GPU_KERNEL(T)                         \
REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)
\end{verbatim}

\subsection{使用 Python 实现梯度
}\label{ux4f7fux7528-python-ux5b9eux73b0ux68afux5ea6}

给定一个 Op 组成的图, TensorFlow 使用自动微分 (反向传播) 来添加新的 Op
以表示梯度运算, 同时 不影响已有的 Op
(参见\href{tensorflow-zh/SOURCE/api_docs/python/train.md\#gradient-computation}{梯度运算}).
为了使自动微分能够与新的 Op 协同工作, 必须注册一个梯度函数, 从 Op
的输入计算梯度, 并返回代表 梯度值的输出.

数学上, 如果一个 Op 计算 \textbackslash{}(y = f(x)\textbackslash{}),
注册的梯度 Op 通过以下链式法则, 将 \textbackslash{}(\partial /
\partial y\textbackslash{}) 的梯度运算转化为 \textbackslash{}(\partial /
\partial x\textbackslash{}) 的梯度运算.

\[\frac{\partial}{\partial x}
    = \frac{\partial}{\partial y} \frac{\partial y}{\partial x}
    = \frac{\partial}{\partial y} \frac{\partial f}{\partial x}.\]

在 \texttt{ZeroOut} 的例子中, 输入中只有一个项会影响输出, 所以,
代表输入的梯度值的 tensor 也只有 一个输入项. 如下所示:

\begin{Shaded}
\begin{Highlighting}[]
\ImportTok{from} \NormalTok{tensorflow.python.framework }\ImportTok{import} \NormalTok{ops}
\ImportTok{from} \NormalTok{tensorflow.python.ops }\ImportTok{import} \NormalTok{array_ops}
\ImportTok{from} \NormalTok{tensorflow.python.ops }\ImportTok{import} \NormalTok{sparse_ops}

\AttributeTok{@ops.RegisterGradient}\NormalTok{(}\StringTok{"ZeroOut"}\NormalTok{)}
\KeywordTok{def} \NormalTok{_zero_out_grad(op, grad):}
  \CommentTok{"""`zero_out` 的梯度.}

\CommentTok{  参数:}
\CommentTok{    op: 欲进行微分的 `zero_out` `操作`, 可以用于获取原始 Op 的输入和输出.}
\CommentTok{    grad: 代表 `zero_out` 输出的梯度 Op.}

\CommentTok{  返回:}
\CommentTok{    代表输入 `zero_out` 的微分.}
\CommentTok{  """}
  \NormalTok{to_zero }\OperatorTok{=} \NormalTok{op.inputs[}\DecValTok{0}\NormalTok{]}
  \NormalTok{shape }\OperatorTok{=} \NormalTok{array_ops.shape(to_zero)}
  \NormalTok{index }\OperatorTok{=} \NormalTok{array_ops.zeros_like(shape)}
  \NormalTok{first_grad }\OperatorTok{=} \NormalTok{array_ops.reshape(grad, [}\OperatorTok{-}\DecValTok{1}\NormalTok{])[}\DecValTok{0}\NormalTok{]}
  \NormalTok{to_zero_grad }\OperatorTok{=} \NormalTok{sparse_ops.sparse_to_dense(index, shape, first_grad, }\DecValTok{0}\NormalTok{)}
  \ControlFlowTok{return} \NormalTok{[to_zero_grad]  }\CommentTok{# 单个 Tensor 的列表, 既然只有一个输入}
\end{Highlighting}
\end{Shaded}

使用
\href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#RegisterGradient}{\texttt{ops.RegisterGradient}}
注册梯度函数需要注意的一些细节:

\begin{itemize}
\item
  对于仅有一个输出的 Op, 梯度函数使用
  \href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#Operation}{\texttt{Operation}}
  \texttt{op} 和一个
  \href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#Tensor}{\texttt{Tensor}}
  \texttt{grad} 作为参数, 并从
  \href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#Operation.inputs}{\texttt{op.inputs{[}i{]}}},
  \href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#Operation.outputs}{\texttt{op.outputs{[}i{]}}},
  和 \texttt{grad} 构建新的 Op. 属性的信息可以通过
  \href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#Operation.get_attr}{\texttt{op.get\_attr}}
  获取.
\item
  如果 Op 有多个输出, 梯度函数将使用 \texttt{op} 和 \texttt{grads}
  作为参数, 其中, \texttt{grads} 是一个 梯度 Op 的列表,
  为每一个输出计算梯度. 梯度函数的输出必须是一个 \texttt{Tensor}
  对象列表, 对应到 每一个输入的梯度.
\item
  如果没有为一些输入定义梯度, 譬如用作索引的整型, 这些输入返回的梯度为
  \texttt{None}. 举一个例子, 如果一个 Op 的输入为一个浮点数 tensor
  \texttt{x} 和一个整型索引 \texttt{i}, 那么梯度函数将返回
  \texttt{{[}x\_grad,\ None{]}}.
\item
  如果梯度对于一个 Op 来说毫无意义, 使用
  \texttt{ops.NoGradient("OpName")} 禁用自动差分.
\end{itemize}

注意当梯度函数被调用时, 作用的对象是数据流图中的 Op, 而不是 tensor
数据本身. 因此, 只有在图运行时, 梯度运算才会被其它 tensorflow Op
的执行动作所触发.

\subsection{在 Python 中实现一个形状函数
}\label{ux5728-python-ux4e2dux5b9eux73b0ux4e00ux4e2aux5f62ux72b6ux51fdux6570}

TensorFlow Python API 有一个 ``形状推断'' 功能, 可以不执行图就获取
tensor 的形状信息. 形状推断功能藉由每一个 Op 类型注册的 ``形状函数''
来支持, 该函数有两个规则: 假设所有输入的 形状必须是兼容的,
以及指定输出的形状. 一个形状函数以一个
\href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#Operation}{\texttt{Operation}}
作为输入, 返回一个
\href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#TensorShape}{\texttt{TensorShape}}
对象列表 (每一个输出一个对象). 使用
\href{tensorflow-zh/SOURCE/api_docs/python/framework.md\#RegisterShape}{\texttt{tf.RegisterShape}
装饰器} 注册形状函数. 例如,
\protect\hyperlink{defineux5finterface}{上文定义的 \texttt{ZeroOut} Op}
的形状函数如下:

\begin{Shaded}
\begin{Highlighting}[]
\AttributeTok{@tf.RegisterShape}\NormalTok{(}\StringTok{"ZeroOut"}\NormalTok{):}
\KeywordTok{def} \NormalTok{_zero_out_shape(op):}
  \CommentTok{"""ZeroOut Op 的形状函数.}

\CommentTok{  这是 ZeroOut 形状函数的无约束版本, 为每一个输出产生的形状和对应的输入一样. }
\CommentTok{  """}
  \ControlFlowTok{return} \NormalTok{[op.inputs[}\DecValTok{0}\NormalTok{].get_shape()]}
\end{Highlighting}
\end{Shaded}

一个形状函数也可以约束输入的形状. 下面是
\protect\hyperlink{Validation}{ZeroOut 形状函数的 vector 输入约束}版本:

\begin{Shaded}
\begin{Highlighting}[]
\AttributeTok{@tf.RegisterShape}\NormalTok{(}\StringTok{"ZeroOut"}\NormalTok{):}
\KeywordTok{def} \NormalTok{_zero_out_shape(op):}
  \CommentTok{"""ZeroOut Op 的形状函数.}

\CommentTok{  这是 ZeroOut 形状函数的约束版本, 要输入的 rank 必须是 1 (即使一个 vector).}
\CommentTok{  """}
  \NormalTok{input_shape }\OperatorTok{=} \NormalTok{op.inputs[}\DecValTok{0}\NormalTok{].get_shape().with_rank(}\DecValTok{1}\NormalTok{)}
  \ControlFlowTok{return} \NormalTok{[input_shape]}
\end{Highlighting}
\end{Shaded}

如果 Op 是\protect\hyperlink{Polymorphism}{多输入的多态 Op},
使用操作的属性来决定需要检查的形状数量:

\begin{verbatim}
@tf.RegisterShape("IntListInputExample")
def _int_list_input_example_shape(op):
  """ "IntListInputExample" Op 的形状函数.

  所有的输入和输出是同大小的矩阵.
  """
  output_shape = tf.TensorShape(None)
  for input in op.inputs:
    output_shape = output_shape.merge_with(input.get_shape().with_rank(2))
  return [output_shape]
\end{verbatim}

既然形状推断是一个可选的特性, 且 tensor 的形状可能动态变化,
形状函数必须足够健壮, 能够处理任意 输入形状信息缺失的情形.
\href{tensorflow-zh/SOURCE/api_docs/python/framework.md}{\texttt{merge\_with}}
方法能够帮助 调用者判断两个形状是否是一样的, 即使两个形状的信息不全,
该函数同样有效.
所有的\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/}{标准
Python Op} 的形状函数都已经定义好了, 并且已经有很多不同的使用示例.

\begin{quote}
原文：\href{http://www.tensorflow.org/how_tos/adding_an_op/index.html\#adding-a-new-op}{Adding
a New Op} 翻译：{[}@doc001{]}(https://github.com/PFZheng)
校对：{[}@ZHNathanielLee{]}(https://github.com/ZHNathanielLee)
\end{quote}


